import os.path
from typing import NamedTuple, Dict

import torch

from centralized_verification.paths import CHECKPOINT_DIR
from experiments.utils.parallel_run import DEVICE


class TrainingState(NamedTuple):
    """Snapshot of training progress that can be written to / read from disk.

    Instances are serialized with ``torch.save`` as a plain dict with the keys
    ``"global_step_count"``, ``"global_episode_count"`` and ``"learner_state"``.
    """

    # Step counter stored under the "global_step_count" key.
    global_step_count: int
    # Episode counter stored under the "global_episode_count" key.
    global_episode_count: int
    # Learner state, stored under the "learner_state" key.
    learner_state_dict: Dict

    def save(self, filename):
        """Write this state to ``filename``, creating its directory if needed.

        Args:
            filename: Destination path of the checkpoint file. May be a bare
                file name (saved relative to the current working directory)
                or a path whose parent directories do not exist yet.
        """
        dir_path = os.path.dirname(filename)
        # A bare file name has an empty dirname; there is nothing to create
        # in that case (and os.mkdir("") / os.makedirs("") would raise).
        if dir_path:
            # makedirs creates every missing intermediate directory and, with
            # exist_ok=True, is a no-op when the directory is already there.
            os.makedirs(dir_path, exist_ok=True)

        torch.save({
            "global_step_count": self.global_step_count,
            "global_episode_count": self.global_episode_count,
            "learner_state": self.learner_state_dict
        }, filename)

    @staticmethod
    def load(filename, map_location: torch.device = "cpu"):
        """Read a checkpoint previously written by :meth:`save`.

        Args:
            filename: Path of the checkpoint file.
            map_location: Forwarded to ``torch.load`` as ``map_location``;
                defaults to ``"cpu"``.

        Returns:
            A ``TrainingState`` built from the stored dict.

        Raises:
            KeyError: If the loaded dict lacks one of the expected keys.
        """
        # NOTE(review): torch.load relies on pickle; only load checkpoints
        # that come from a trusted source.
        d = torch.load(filename, map_location=map_location)
        return TrainingState(
            global_step_count=d["global_step_count"],
            global_episode_count=d["global_episode_count"],
            learner_state_dict=d["learner_state"]
        )


def maybe_load_from_checkpoint(run_name):
    """Load the checkpoint for ``run_name`` if one exists on disk.

    The checkpoint is looked up at ``{CHECKPOINT_DIR}/{run_name}.pt``.

    Args:
        run_name: Name of the run; used as the checkpoint file's stem.

    Returns:
        The ``TrainingState`` loaded with ``DEVICE`` as the map location, or
        ``None`` when no checkpoint file exists at that path.
    """
    checkpoint_file = f"{CHECKPOINT_DIR}/{run_name}.pt"
    # Guard clause: nothing to resume from.
    if not os.path.exists(checkpoint_file):
        return None
    return TrainingState.load(checkpoint_file, DEVICE)
